#  -*- coding: utf-8 -*-
#  Copyright (c) 2025-2025 Huawei Technologies Co., Ltd.
#  #
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#  #
#  http://www.apache.org/licenses/LICENSE-2.0
#  #
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.

"""
msmodelslim.utils.patch.torch 模块的单元测试
"""

import json
import os
from unittest.mock import patch, MagicMock, Mock
import pytest
from transformers import AutoModelForCausalLM, AutoTokenizer
from msmodelslim.utils.security.model import (
    SafeGenerator,
    InvalidModelError,
    get_valid_read_path,
    exception_handler,
)


def create_mock_model_dir(tmp_path):
    """创建模拟的预训练模型目录（含config、tokenizer、模型权重文件）"""
    model_dir = tmp_path / "mock_model"
    model_dir.mkdir(exist_ok=True)

    config_data = {"model_type": "gpt2", "vocab_size": 50257}
    with open(model_dir / "config.json", "w", encoding="utf-8") as f:
        json.dump(config_data, f)

    with open(model_dir / "vocab.txt", "w", encoding="utf-8") as f:
        f.write("<|endoftext|>\nhello\nworld")
    tokenizer_config = {"do_lower_case": False}
    with open(model_dir / "tokenizer_config.json", "w", encoding="utf-8") as f:
        json.dump(tokenizer_config, f)

    with open(model_dir / "pytorch_model.bin", "wb") as f:
        f.write(b"dummy_weight")

    return str(model_dir)


def create_mock_jsonl(tmp_path, file_name="test_dataset.jsonl", is_humaneval=False):
    """创建模拟的 JSONL 数据集文件"""
    jsonl_path = tmp_path / file_name
    lines = []
    if is_humaneval:
        lines = [
            json.dumps({"prompt": "def add(a, b): return ", "other_key": "val1"}),
            json.dumps({"prompt": "def multiply(x, y): return ", "other_key": "val2"}),
        ]
    else:
        lines = [
            json.dumps({"inputs_pretokenized": "test text 1", "other_key": "val1"}),
            json.dumps({"inputs_pretokenized": "test text 2", "other_key": "val2"}),
        ]
    with open(jsonl_path, "w", encoding="utf-8") as f:
        f.write("\n".join(lines))
    return str(jsonl_path)


class TestSafeGenerator:

    @staticmethod
    def test_load_jsonl_normal_success(mock_self, tmp_path):
        """测试加载普通 JSONL 数据集（默认 key: inputs_pretokenized）"""
        mock_jsonl_path = create_mock_jsonl(tmp_path, is_humaneval=False)

        dataset = mock_self.safe_gen.load_jsonl(mock_jsonl_path)

        assert len(dataset) == 2
        assert dataset == ["test text 1", "test text 2"]

    @staticmethod
    def test_load_jsonl_fail_invalid_json(mock_self, tmp_path):
        """测试加载格式错误的 JSONL 文件（某行不是合法 JSON）"""
        invalid_jsonl_path = tmp_path / "invalid.jsonl"
        with open(invalid_jsonl_path, "w", encoding="utf-8") as f:
            f.write('{"valid": "line"}\ninvalid json line')

        with pytest.raises(json.JSONDecodeError):
            mock_self.safe_gen.load_jsonl(str(invalid_jsonl_path))

    @staticmethod
    def test_load_jsonl_fail_file_not_exist(mock_self, tmp_path):
        """测试加载不存在的 JSONL 文件"""
        non_exist_path = str(tmp_path / "non_exist.jsonl")

        with pytest.raises(FileNotFoundError):
            mock_self.safe_gen.load_jsonl(non_exist_path)

    @pytest.fixture
    def mock_self(self):
        mock = Mock()
        mock.safe_gen = SafeGenerator()
        return mock

    @patch("msmodelslim.utils.security.model.get_valid_read_path")
    def test_get_config_from_pretrained_fail_missing_config(
        self, mock_get_valid, mock_self, tmp_path
    ):
        """测试加载缺失 config.json 的模型目录（异常场景）"""
        invalid_model_dir = tmp_path / "invalid_model"
        invalid_model_dir.mkdir(exist_ok=True)
        mock_get_valid.return_value = str(invalid_model_dir)

        with pytest.raises(InvalidModelError) as excinfo:
            mock_self.safe_gen.get_config_from_pretrained(str(invalid_model_dir))

        assert "Get config from pretrained failed" in str(excinfo.value)
        assert "ensure config files all exist" in str(excinfo.value)

    @patch("msmodelslim.utils.security.model.get_valid_read_path")
    def test_get_model_from_pretrained_success(
        self, mock_get_valid, mock_self, tmp_path
    ):
        """测试正常加载预训练模型"""
        mock_model_dir = create_mock_model_dir(tmp_path)
        mock_get_valid.return_value = mock_model_dir

        with patch.object(
            AutoModelForCausalLM, "from_pretrained"
        ) as mock_from_pretrained:
            mock_model = MagicMock()
            mock_from_pretrained.return_value = mock_model

            model = mock_self.safe_gen.get_model_from_pretrained(mock_model_dir)

            mock_from_pretrained.assert_called_once_with(
                mock_model_dir, local_files_only=True
            )
            assert model is mock_model

    @patch("msmodelslim.utils.security.model.get_valid_read_path")
    def test_get_model_from_pretrained_fail_corrupt_weight(
        self, mock_get_valid, mock_self, tmp_path
    ):
        """测试加载权重损坏的模型（异常场景）"""
        mock_model_dir = create_mock_model_dir(tmp_path)
        with open(
            os.path.join(mock_model_dir, "pytorch_model.bin"), "w", encoding="utf-8"
        ) as f:
            f.write("corrupt_data")
        mock_get_valid.return_value = mock_model_dir

        with pytest.raises(InvalidModelError) as excinfo:
            mock_self.safe_gen.get_model_from_pretrained(mock_model_dir)

        assert "Get model from pretrained failed" in str(excinfo.value)
        assert "model weights files all exist and are valid" in str(excinfo.value)

    @patch("msmodelslim.utils.security.model.get_valid_read_path")
    def test_get_tokenizer_from_pretrained_success(
        self, mock_get_valid, mock_self, tmp_path
    ):
        """测试正常加载分词器"""
        mock_model_dir = create_mock_model_dir(tmp_path)
        mock_get_valid.return_value = mock_model_dir

        with patch.object(AutoTokenizer, "from_pretrained") as mock_from_pretrained:
            mock_tokenizer = MagicMock()
            mock_from_pretrained.return_value = mock_tokenizer

            tokenizer = mock_self.safe_gen.get_tokenizer_from_pretrained(mock_model_dir)

            mock_from_pretrained.assert_called_once_with(
                mock_model_dir, local_files_only=True
            )
            assert tokenizer is mock_tokenizer

    @patch("msmodelslim.utils.security.model.get_valid_read_path")
    def test_get_tokenizer_from_pretrained_fail_missing_vocab(
        self, mock_get_valid, mock_self, tmp_path
    ):
        """测试加载缺失 vocab.txt 的分词器（异常场景）"""
        mock_model_dir = tmp_path / "no_vocab_model"
        mock_model_dir.mkdir(exist_ok=True)
        with open(mock_model_dir / "config.json", "w", encoding="utf-8") as f:
            json.dump({"model_type": "gpt2"}, f)
        mock_get_valid.return_value = str(mock_model_dir)

        with pytest.raises(InvalidModelError) as excinfo:
            mock_self.safe_gen.get_tokenizer_from_pretrained(str(mock_model_dir))

        assert "Get tokenizer from pretrained failed" in str(excinfo.value)
        assert "tokenizer files all exist and are valid" in str(excinfo.value)
